import torch
from torch.fx import GraphModule  # type: ignore
from torch.fx.symbolic_trace import Tracer  # type: ignore
from torch.fx.node import Target, Node, Argument  # type: ignore
from .fx import Fuser  # noqa: F401
from .fx import Quantizer  # noqa: F401
from .fx.utils import graph_pretty_str  # noqa: F401
from .fx.utils import get_custom_module_class_keys  # noqa: F401
from torch.nn.intrinsic import _FusedModule
from typing import Dict, Any, List, Callable, Tuple, Optional

def _check_is_graph_module(model: torch.nn.Module) -> None:
    if not isinstance(model, GraphModule):
        raise ValueError(
            'input model must be a GraphModule, ' +
            'got type: ' + str(type(model)) + '. Please make ' +
            'sure to follow the tutorials.')

def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
    r""" Swap FloatFunctional with FXFloatFunctional
    """
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, torch.nn.quantized.FloatFunctional):
            modules_to_swap.append(name)
        else:
            _swap_ff_with_fxff(module)

    for name in modules_to_swap:
        del model._modules[name]
        model._modules[name] = torch.nn.quantized.FXFloatFunctional()

def _fuse_fx(
        graph_module: GraphModule,
        fuse_custom_config_dict: Optional[Dict[str, Any]] = None) -> GraphModule:
    r""" Internal helper function to fuse modules in preparation for quantization

    Args:
        graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
    """
    _check_is_graph_module(graph_module)
    fuser = Fuser()
    return fuser.fuse(graph_module, fuse_custom_config_dict)

class Scope(object):
    """ Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example:
    class Sub(torch.nn.Module):
        def forward(self, x):
            # This will be a call_method Node in GraphModule,
            # scope for this would be (module_path="sub", module_type=Sub)
            return x.transpose(1, 2)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sub = Sub()

        def forward(self, x):
            # This will be a call_method Node as well,
            # scope for this would be (module_path="", module_type=None)
            x = x.transpose(1, 2)
            x = self.sub(x)
            return x

    """
    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type

class ScopeContextManager(object):
    """ A context manager to track the Scope of Node during symbolic
    tracing.
    When entering a forward function of a Module, we'll update the scope information of
    the current module, and when we exit, we'll restore the previous scope information.
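
    A minimal usage sketch (the module and path are illustrative):

    ```python
    scope = Scope("", None)
    sub = torch.nn.Linear(4, 4)
    with ScopeContextManager(scope, sub, "sub"):
        # here scope.module_path == "sub" and
        # scope.module_type == torch.nn.Linear
        pass
    # on exit the previous ("", None) scope is restored
    ```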
    """
    def __init__(
            self,
            scope: Scope,
            current_module: torch.nn.Module,
            current_module_path: str):
        super().__init__()
        self.prev_module_type = scope.module_type
        self.prev_module_path = scope.module_path
        self.scope = scope
        self.scope.module_path = current_module_path
        self.scope.module_type = type(current_module)

    def __enter__(self):
        return

    def __exit__(self, *args):
        self.scope.module_path = self.prev_module_path
        self.scope.module_type = self.prev_module_type
        return


class QuantizationTracer(Tracer):
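    """ Tracer for preparing a model for quantization. It behaves like the
    default torch.fx Tracer, except that it:
    (1) treats standard torch.nn modules (other than containers such as
        torch.nn.Sequential), user-skipped modules, and fused modules as
        leaf modules, and
    (2) records, in node_name_to_scope, the Scope (module path and module
        type) in which each call_method node was created.
    """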
    def __init__(
            self,
            skipped_module_names: List[str],
            skipped_module_classes: List[Callable]):
        super().__init__()
        self.skipped_module_names = skipped_module_names
        self.skipped_module_classes = skipped_module_classes
        # NB: initialize the module_type of the top level module to None.
        # We are assuming people won't configure the model with the type of
        # the top level module here, since people can use "" for the global
        # config. We can change this if there is a use case that configures
        # qconfig using the top level module type.
        self.scope = Scope("", None)
        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
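        # A module is a leaf (i.e. not traced through) if it is a standard
        # torch.nn module (except containers such as Sequential), if the
        # user asked to skip it by name or by class, or if it is a fused
        # module produced by fusion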
        return (m.__module__.startswith("torch.nn") and
                not isinstance(m, torch.nn.Sequential)) or \
            module_qualified_name in self.skipped_module_names or \
            type(m) in self.skipped_module_classes or \
            isinstance(m, _FusedModule)

    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        module_qualified_name = self.path_of_module(m)
        # Creating scope with information of current module
        # scope will be restored automatically upon exit
        with ScopeContextManager(self.scope, m, module_qualified_name):
            return super().call_module(m, forward, args, kwargs)

    def create_node(self, kind: str, target: Target,
                    args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: Optional[str] = None,
                    type_expr: Optional[Any] = None) -> Node:
        node = super().create_node(kind, target, args, kwargs, name, type_expr)
        if kind == "call_method":
            self.node_name_to_scope[node.name] = (self.scope.module_path, self.scope.module_type)
        return node

def _prepare_fx(model: torch.nn.Module, qconfig_dict: Any,
                prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
                is_standalone_module: bool = False) -> GraphModule:
    r""" Internal helper function for prepare_fx
    Args:
      `model`, `qconfig_dict`, `prepare_custom_config_dict`: see docs for :func:`~torch.quantization.prepare_fx`
      `is_standalone_module`: a boolean flag that indicates whether we are
      quantizing a standalone module or not; a standalone module
      is a submodule of the parent module that is not inlined in the
      forward graph of the parent module.
      The way we quantize a standalone module is described in:
      :func:`~torch.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}

    skipped_module_names = prepare_custom_config_dict.get("non_traceable_module_name", [])
    skipped_module_classes = prepare_custom_config_dict.get("non_traceable_module_class", [])

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    # symbolically trace the model
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        standalone_module_name_configs = prepare_custom_config_dict.get("standalone_module_name", [])
        skipped_module_names += [config[0] for config in standalone_module_name_configs]

        standalone_module_class_configs = prepare_custom_config_dict.get("standalone_module_class", [])
        skipped_module_classes += [config[0] for config in standalone_module_class_configs]
        float_custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict, "float_to_observed_custom_module_class")
        skipped_module_classes += float_custom_module_classes
    tracer = QuantizationTracer(
        skipped_module_names, skipped_module_classes)
    graph_module = GraphModule(model, tracer.trace(model))
    graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
    quantizer = Quantizer()
    prepared = quantizer.prepare(
        graph_module,
        qconfig_dict,
        tracer.node_name_to_scope,
        prepare_custom_config_dict=prepare_custom_config_dict,
        is_standalone_module=is_standalone_module)

    preserved_attributes = prepare_custom_config_dict.get("preserved_attributes", [])
    for attr_name in preserved_attributes:
        setattr(prepared, attr_name, getattr(model, attr_name))
    return prepared

def _prepare_standalone_module_fx(
        model: torch.nn.Module,
        qconfig_dict: Any,
        prepare_custom_config_dict: Optional[Dict[str, Any]] = None) -> GraphModule:
    r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
    parent module.
    standalone_module means it is a submodule that is not inlined in the parent module,
        and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config_dict for the standalone module

    Returns:
        model(GraphModule): prepared standalone module
        attributes:
            _standalone_module_input_quantized_idxs(List[Int]): a list of
                indices for the graph inputs that are expected to be
                quantized, same as the input_quantized_idxs configuration
                provided for the standalone module
            _standalone_module_output_quantized_idxs(List[Int]): a list of
                indices for the graph outputs that are quantized, same as
                the output_quantized_idxs configuration provided for the
                standalone module
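
    For example, to observe a standalone module whose first input arrives
    already quantized and whose outputs stay in fp32, the
    prepare_custom_config_dict passed for the standalone module could be
    (a sketch):

    ```python
    standalone_prepare_config_dict = {
        "input_quantized_idxs": [0],  # input 0 is expected to be quantized
        "output_quantized_idxs": [],  # all outputs are returned in fp32
    }
    ```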
    """
    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True)

def fuse_fx(model: torch.nn.Module,
            fuse_custom_config_dict: Optional[Dict[str, Any]] = None) -> GraphModule:
    r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
    Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
    Args:
        `model`: a torch.nn.Module model
        `fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g.
         fuse_custom_config_dict = {
           "additional_fuser_method_mapping": {
             (Module1, Module2): fuse_module1_module2
           }
         }
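
    where `fuse_module1_module2` is a user-provided fuser method. A minimal
    sketch, assuming the pattern fuses a Module1 instance followed by a
    Module2 instance and `FusedModule1Module2` is a user-defined module:

    ```python
    def fuse_module1_module2(mod1, mod2):
        # return a single module equivalent to running mod1 then mod2
        return FusedModule1Module2(mod1, mod2)
    ```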

    Example:
    ```python
    from torch.quantization import fuse_fx
    m = Model().eval()
    m = fuse_fx(m)
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
    assert not model.training, 'fuse_fx only works on models in eval mode'
    graph_module = torch.fx.symbolic_trace(model)  # type: ignore
    return _fuse_fx(graph_module, fuse_custom_config_dict)

def prepare_fx(
        model: torch.nn.Module, qconfig_dict: Any,
        prepare_custom_config_dict: Optional[Dict[str, Any]] = None) -> GraphModule:
    r""" Prepare a model for post training static quantization

    Args:
      `model`: torch.nn.Module model, must be in eval mode
      `qconfig_dict`: a dictionary with the following configurations:
      qconfig_dict = {
      # optional, global config
      "": qconfig?,

      # optional, used for module and function types
      # could also be split into module_types and function_types if we prefer
      "object_type": [
        (torch.nn.Conv2d, qconfig?),
        (torch.nn.functional.add, qconfig?),
        ...,
       ],

      # optional, used for module names
      "module_name": [
        ("foo.bar", qconfig?)
        ...,
      ],

      # optional, matched in order, first match takes precedence
      "module_name_regex": [
        ("foo.*bar.*conv[0-9]+", qconfig?)
        ...,
      ],
      # priority (in increasing order): global, object_type, module_name_regex, module_name
      # qconfig == None means fusion and quantization should be skipped for anything
      # matching the rule
      }
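
      For instance, a qconfig_dict combining the rules above might look like
      the following (default_qconfig, linear_qconfig and conv_qconfig are
      illustrative qconfig objects):

      ```python
      qconfig_dict = {
          "": default_qconfig,  # global fallback
          "object_type": [(torch.nn.Linear, linear_qconfig)],
          "module_name_regex": [("foo.*bar.*conv[0-9]+", conv_qconfig)],
          # module_name has the highest priority; qconfig None skips quantization
          "module_name": [("foo.bar", None)],
      }
      ```
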
      `prepare_custom_config_dict`: customization configuration dictionary for
      the quantization tool:
      prepare_custom_config_dict = {
        # optional: specify the path for standalone modules
        # These modules are symbolically traced and quantized as one unit
        "standalone_module_name": [
           # module_name, qconfig_dict, prepare_custom_config_dict
           ("submodule.standalone",
            None,  # qconfig_dict for the prepare function called in the submodule,
                   # None means use qconfig from parent qconfig_dict
            {"input_quantized_idxs": [], "output_quantized_idxs": []})  # prepare_custom_config_dict
        ],

        "standalone_module_class": [
            # module_class, qconfig_dict, prepare_custom_config_dict
            (StandaloneModule,
             None,  # qconfig_dict for the prepare function called in the submodule,
                    # None means use qconfig from parent qconfig_dict
            {"input_quantized_idxs": [0], "output_quantized_idxs": [0]})  # prepare_custom_config_dict
        ],

        # user will manually define the corresponding observed
        # module class which has a from_float class method that converts
        # a float custom module to an observed custom module
        # (only needed for static quantization; see the from_float
        # sketch after this dict)
        "float_to_observed_custom_module_class": {
           "static": {
               CustomModule: ObservedCustomModule
           }
        },

        # the qualified names of submodules that are not symbolically traceable
        "non_traceable_module_name": [
           "non_traceable_module"
        ],

        # the module classes that are not symbolically traceable
        # we'll also put dynamic/weight_only custom modules here
        "non_traceable_module_class": [
           NonTraceableModule
        ],

        # Additional fuser_method mapping
        "additional_fuser_method_mapping": {
           (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
        },

        # Additional module mapping for qat
        "additional_qat_module_mapping": {
           torch.nn.intrinsic.ConvBn2d: torch.nn.qat.ConvBn2d
        },

        # Additional fusion patterns
        "additional_fusion_pattern": {
           (torch.nn.BatchNorm2d, torch.nn.Conv2d): ConvReluFusionHandler
        },

        # Additional quantization patterns
        "additional_quant_pattern": {
           torch.nn.Conv2d: ConvReluQuantizeHandler,
           (torch.nn.ReLU, torch.nn.Conv2d): ConvReluQuantizeHandler,
        }

        # By default, inputs and outputs of the graph are assumed to be in
        # fp32. Providing `input_quantized_idxs` will set the inputs with the
        # corresponding indices to be quantized. Providing
        # `output_quantized_idxs` will set the outputs with the corresponding
        # indices to be quantized.
        "input_quantized_idxs": [0],
        "output_quantized_idxs": [0],

        # Attributes that are not used in the forward function will
        # be removed when constructing the GraphModule; this is a list of
        # attributes to preserve as attributes of the GraphModule even
        # when they are not used in the code
        "preserved_attributes": ["preserved_attr"],
      }
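
      The observed custom module class referenced in
      `float_to_observed_custom_module_class` is expected to provide a
      `from_float` class method, e.g. (a sketch; the attribute handling is
      illustrative):

      ```python
      class ObservedCustomModule(torch.nn.Module):
          @classmethod
          def from_float(cls, float_module):
              # build an observed module from the float custom module
              observed = cls()
              observed.weight = float_module.weight
              return observed
      ```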


    Return:
      A GraphModule with observers (configured by qconfig_dict), ready for calibration

    Example:
    ```python
    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization import prepare_fx

    float_model.eval()
    qconfig = get_default_qconfig('fbgemm')
    def calibrate(model, data_loader):
        model.eval()
        with torch.no_grad():
            for image, target in data_loader:
                model(image)

    qconfig_dict = {"": qconfig}
    prepared_model = prepare_fx(float_model, qconfig_dict)
    # Run calibration
    calibrate(prepared_model, sample_inference_data)
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    assert not model.training, 'prepare_fx only works for models in ' + \
        'eval mode'
    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)

def prepare_qat_fx(
        model: torch.nn.Module, qconfig_dict: Any,
        prepare_custom_config_dict: Optional[Dict[str, Any]] = None) -> GraphModule:
    r""" Prepare a model for quantization aware training
    Args:
      `model`: torch.nn.Module model, must be in train mode
      `qconfig_dict`: see :func:`~torch.quantization.prepare_fx`
      `prepare_custom_config_dict`: see :func:`~torch.quantization.prepare_fx`

    Return:
      A GraphModule with fake quant modules (configured by qconfig_dict), ready for
      quantization aware training

    Example:
    ```python
    import torch
    from torch.quantization import get_default_qat_qconfig
    from torch.quantization import prepare_qat_fx

    qconfig = get_default_qat_qconfig('fbgemm')
    def train_loop(model, train_data):
        model.train()
        for image, target in train_data:
            ...

    float_model.train()
    qconfig_dict = {"": qconfig}
    prepared_model = prepare_qat_fx(float_model, qconfig_dict)
    # Run training
    train_loop(prepared_model, train_data)
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
    assert model.training, 'prepare_qat_fx only works for models in ' + \
        'train mode'
    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)

def _convert_fx(
        graph_module: GraphModule, debug: bool,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None,
        is_standalone_module: bool = False) -> GraphModule:
    """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx`
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}

    _check_is_graph_module(graph_module)

    quantizer = Quantizer()
    quantized = quantizer.convert(graph_module, debug, convert_custom_config_dict, is_standalone_module)

    preserved_attributes = convert_custom_config_dict.get("preserved_attributes", [])
    for attr_name in preserved_attributes:
        setattr(quantized, attr_name, getattr(graph_module, attr_name))
    return quantized

def convert_fx(
        graph_module: GraphModule, debug: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> GraphModule:
    r""" Convert a calibrated or trained model to a quantized model
    Args:
        `graph_module`: A prepared and calibrated/trained model (GraphModule)
        `debug`: flag for producing a debug-friendly model (preserves weight attributes)
        `convert_custom_config_dict`: dictionary for custom configurations for convert function:
        convert_custom_config_dict = {

          # additional object (module/operator) mappings that will overwrite the default
          # module mappings
          "additional_object_mapping": {
             "static": {
                FloatModule: QuantizedModule,
                float_op: quantized_op
             },
             "dynamic": {
                FloatModule: DynamicallyQuantizedModule,
                float_op: dynamically_quantized_op
             },
          },

          # user will manually define the corresponding quantized
          # module class which has a from_observed class method that converts
          # an observed custom module to a quantized custom module
          # (see the from_observed sketch after this dict)
          "observed_to_quantized_custom_module_class": {
             "static": {
                 ObservedCustomModule: QuantizedCustomModule
             },
             "dynamic": {
                 ObservedCustomModule: QuantizedCustomModule
             },
             "weight_only": {
                 ObservedCustomModule: QuantizedCustomModule
             }
          },

          # Attributes that are not used in the forward function will
          # be removed when constructing the GraphModule; this is a list of
          # attributes to preserve as attributes of the GraphModule even
          # when they are not used in the code
          "preserved_attributes": ["preserved_attr"],
        }
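
        The quantized custom module class referenced in
        `observed_to_quantized_custom_module_class` is expected to provide a
        `from_observed` class method, e.g. (a sketch):

        ```python
        class QuantizedCustomModule(torch.nn.Module):
            @classmethod
            def from_observed(cls, observed_module):
                # build a quantized module from the observed custom module,
                # typically using the observer statistics it collected
                quantized = cls()
                return quantized
        ```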

    Return:
        A quantized model (GraphModule)

    Example:
    ```python
    # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
    quantized_model = convert_fx(prepared_model)
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
    return _convert_fx(graph_module, debug, convert_custom_config_dict)

def _convert_standalone_module_fx(
        graph_module: GraphModule, debug: bool = False,
        convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> GraphModule:
    r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx`
    and convert it to a quantized model

    Returns a quantized standalone module, whether input/output is quantized is
    specified by prepare_custom_config_dict, with
    input_quantized_idxs, output_quantized_idxs, please
    see docs for prepare_fx for details
    """
    return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True)
